import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy import linalg

# ----------------------helper functions------------------


def get_index(val, ar):
    """
    Return the index of the entry of ``ar`` closest to ``val``.

    Closeness is the absolute difference; on ties the first such index wins
    (np.argmin semantics).
    """
    distances = np.abs(np.asarray(ar) - val)
    return np.argmin(distances)


def get_subset(i, win=1, ar=None):
    """
    Return the window of 2*win+1 entries of ``ar`` centred on index ``i``.

    Near either edge the window is shifted inward (not shrunk) so that it
    still holds 2*win+1 entries. If ``ar`` has fewer entries than that, all of
    ``ar`` is returned.

    i: index around which entries are to be taken
    win: half-width of the window (used for local linear fits)
    ar: ordered sequence (list or 1-D array) to take the window from

    Returns a slice of ``ar`` (same type as ``ar``).
    Raises TypeError if ``ar`` is None.
    """
    if ar is None:
        raise TypeError("get_subset: 'ar' must be a sequence, got None")
    length = len(ar)
    size = 2 * win + 1
    # Clamp [i-win, i+win] to [0, length-1], shifting it inward at the edges.
    # The outer max(..., 0) prevents a negative start, which would make the
    # slice wrap around and silently return only the tail of `ar`.
    start = min(max(i - win, 0), max(length - size, 0))
    return ar[start:start + size]


def window_linear_interp(x, ar_x, ar_y, ar_xerr=None, ar_yerr=None, xerr=None, win=2,
                         tol=0.01, max_iter=100):
    """
    Evaluate a local straight-line fit at ``x`` and estimate its error.

    A line y = a*t + b is fitted to the window of 2*win+1 points of
    (ar_x, ar_y) around the ar_x entry closest to ``x`` (see get_subset).
    The fit is then evaluated at ``x``.

    x: x value where the interpolation is evaluated
    ar_x: x values used for interpolation (get_subset assumes they are ordered)
    ar_y: y values, same shape as ar_x
    ar_xerr: x errors, same shape as ar_x; only used when ar_yerr is also given
    ar_yerr: y errors, same shape as ar_x; if given, the fit is weighted with
             sigma=ar_yerr and absolute_sigma=True
    xerr: error on ``x``; if given, the error estimate also considers the fit
          lines evaluated at x - xerr and x + xerr
    win: half-width of the fit window
    tol: convergence tolerance on the slope for the x-error iteration
    max_iter: maximum number of refits in the x-error iteration (a warning is
              issued if it does not converge)

    Returns (y, y_err):
    - y is the best-fit line evaluated at ``x``.
    - y_err is the largest absolute deviation from y among lines whose
      parameter offsets p satisfy p . pcov^-1 . p < 1 (the 1-sigma fit
      ellipse). The offsets are sampled on a 50x50 grid spanning +-perr.
    """
    import warnings

    def ffunc(t, a, b):
        # straight line used for the local fit
        return a * t + b

    i = get_index(x, ar_x)  # index of the ar_x entry closest to x
    fit_x = np.array(get_subset(i, win=win, ar=ar_x))
    fit_y = np.array(get_subset(i, win=win, ar=ar_y))

    if ar_yerr is not None:
        fit_yerr = np.array(get_subset(i, win=win, ar=ar_yerr))
        popt, pcov = curve_fit(ffunc, fit_x, fit_y, sigma=fit_yerr, absolute_sigma=True)

        if ar_xerr is not None:
            # Errors on both x and y: repeatedly refit with effective errors
            # sigma_eff^2 = sigma_y^2 + a^2 * sigma_x^2 (a = previous slope)
            # until the slope changes by no more than tol. See
            # http://physics.princeton.edu/~mcdonald/examples/statistics/madansky_jassa_54_173_59.pdf
            fit_xerr = np.array(get_subset(i, win=win, ar=ar_xerr))
            for _ in range(max_iter):
                slope_old = popt[0]
                sigma_eff = np.sqrt(fit_yerr**2 + slope_old**2 * fit_xerr**2)
                popt, pcov = curve_fit(ffunc, fit_x, fit_y, sigma=sigma_eff, absolute_sigma=True)
                # "not >" rather than "<=" so that a NaN change also stops
                if not np.abs(popt[0] - slope_old) > tol:
                    break
            else:
                warnings.warn("window_linear_interp: x-error iteration did not converge "
                              "within %d iterations" % max_iter)
    else:
        popt, pcov = curve_fit(ffunc, fit_x, fit_y)

    perr = np.sqrt(np.diag(pcov))
    # For a parameter offset p, 0.5 * p.H.p with H = 2 * pcov^-1 is the local
    # change of chi^2 around the optimum; < 1 selects the fit ellipse.
    hess = 2 * linalg.inv(pcov)

    # Sample offsets on a 50x50 grid spanning +-perr in each parameter and
    # keep those inside the ellipse (vectorised instead of a Python double loop).
    d_slope, d_offset = np.meshgrid(np.linspace(-perr[0], perr[0], 50),
                                    np.linspace(-perr[1], perr[1], 50),
                                    indexing='ij')
    d_slope = d_slope.ravel()
    d_offset = d_offset.ravel()
    chi2_change = 0.5 * (d_slope * (hess[0, 0] * d_slope + hess[0, 1] * d_offset)
                         + d_offset * (hess[1, 0] * d_slope + hess[1, 1] * d_offset))
    inside = chi2_change < 1.
    if np.any(inside):
        slopes = popt[0] + d_slope[inside]
        offsets = popt[1] + d_offset[inside]
    else:
        # No grid point fell inside the ellipse: fall back to the best-fit line
        # instead of failing on an empty max().
        slopes = np.array([popt[0]])
        offsets = np.array([popt[1]])

    y_best = ffunc(x, *popt)
    # With an x error, lines evaluated at x +- xerr also count towards the error.
    x_eval = [x] if xerr is None else [x, x - xerr, x + xerr]
    ffunc_vals = np.concatenate([ffunc(t, slopes, offsets) for t in x_eval])
    ffunc_err = np.max(np.abs(ffunc_vals - y_best))

    return y_best, ffunc_err


# ----------------------------------------load data-----------------------------------------

# Nominal doping grid (percent) as strings. The per-doping interpolation loop
# iterates dopings_AS_d, not this list.
dopings = [str(el) for el in np.arange(0, 35, 2)]

# Accumulators.
str_len_mz = []  # not filled by the code shown here
str_len_c1 = []  # not filled by the code shown here
str_len_mz_interp = []
str_len_mz_interp_err = []
str_len_c1_interp = []
str_len_c1_interp_err = []
crit_mz = []
crit_c1 = []

# Every backslash escaped explicitly (the old literal relied on the invalid
# escapes "\A" and "\F" being kept verbatim); the resulting path is unchanged.
data_dir = "V:\\Paper Data\\AFM_StringTheory\\FCS\\"

# Experimental doping data (saved below under the '*_exp' keys).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
fn_d = "cold_avgs.pkl"
print("Reading doping data from file:  %s" % fn_d)
with open(data_dir + fn_d, 'rb') as f:
    dat_d = pickle.load(f)
# dopings are stored in percent; convert to fractions
dopings_d = [d / 100. for d in dat_d['doping']]
dopingerrs_d = [de / 100. for de in dat_d['dopingerr']]

# NOTE(review): the magnetization uses only arr[0], while its error uses all
# of arr (np.std over arr, np.size(arr)). These agree only if arr holds a
# single row of samples; confirm the layout of dat_d['mzs'].
mzs_d = [np.sqrt(np.mean(np.array(arr[0])**2)) for arr in dat_d['mzs']]
# error of sqrt(<m^2>): 0.5 * (std(m^2) / sqrt(N)) / mz
mzse_d = [0.5 * (np.std(np.array(arr)**2) / mz) / np.sqrt(np.size(np.array(arr)))
          for arr, mz in zip(dat_d['mzs'], mzs_d)]
c1s_d = [cf[0] for cf in dat_d['cfs']]
c1es_d = [cfe[0] for cfe in dat_d['cfes']]
scs_d = dat_d['sc']
sces_d = dat_d['sces']
asls_d = dat_d['asl']
asles_d = dat_d['aslerr']

# Reference data set (saved below under the '*_theo' keys); dopings stay in percent.
fn_AS_d = "gst0.60J_cold_avgs.pkl"
print("Reading doping data from file:  %s" % fn_AS_d)
with open(data_dir + fn_AS_d, 'rb') as f:
    dat_AS_d = pickle.load(f)

dopings_AS_d = dat_AS_d['doping']
scs_AS_d = dat_AS_d['sc']
sces_AS_d = dat_AS_d['sces']
asls_AS_d = dat_AS_d['asl']
asles_AS_d = dat_AS_d['aslerr']


#----------------------make interpolations------------------

# For every doping of the reference set:
#   1. read its ts{doping}_asl.pkl file;
#   2. interpolate the experimental Mz and C1 at that doping;
#   3. use those values to interpolate the average string length of the ts data.
data_dir = "V:\\Paper Data\\AFM_StringTheory\\FCS\\"  # loop-invariant, set once

for d_str in dopings_AS_d:
    fn_ts = "ts{}_asl.pkl".format(d_str)
    print("Reading doping data from file:  %s" % fn_ts)
    with open(data_dir + fn_ts, 'rb') as f:
        dat_ts = pickle.load(f)

    d = float(d_str) / 100.  # percent -> fraction, same units as dopings_d
    # NOTE(review): window_linear_interp picks its fit window by index, so
    # these arrays are presumably ordered in mzs / c1s; confirm.
    mzs = dat_ts['mzs_ts']
    mzes = dat_ts['mzse_ts']
    c1s = dat_ts['c1s_ts']
    c1es = dat_ts['c1es_ts']
    asls = dat_ts['asls_ts']
    asles = dat_ts['asles_ts']

    # Experimental Mz and C1 (with y errors) interpolated at doping d
    val_mz, val_mz_err = window_linear_interp(d, dopings_d, mzs_d, ar_yerr=mzse_d, win=3)
    val_c1, val_c1_err = window_linear_interp(d, dopings_d, c1s_d, ar_yerr=c1es_d, win=3)

    print('  Using mZ for effective T:')
    print('   Doping: %s  Mz: %s  Mz_err: %s' % (d, val_mz, val_mz_err))

    print('  Using c1 for effective T:')
    print('   Doping:  %s  C1: %s  C1_err: %s' % (d, val_c1, val_c1_err))

    # String length at that Mz: errors on both axes of the ts data plus the
    # error on the interpolated Mz are propagated.
    len_mz, len_mz_err = window_linear_interp(val_mz, mzs, asls, ar_xerr=mzes, ar_yerr=asles,
                                              xerr=val_mz_err, win=15)
    str_len_mz_interp.append(len_mz)
    str_len_mz_interp_err.append(len_mz_err)
    print('Error: %s' % len_mz_err)

    # Same, using C1 instead of Mz
    len_c1, len_c1_err = window_linear_interp(val_c1, c1s, asls, ar_xerr=c1es, ar_yerr=asles,
                                              xerr=val_c1_err, win=15)
    str_len_c1_interp.append(len_c1)
    str_len_c1_interp_err.append(len_c1_err)
    print('Error: %s' % len_c1_err)

    crit_mz.append(val_mz)
    crit_c1.append(val_c1)


#----------------------check doping data interpol------------------

def _plot_interp_check(xs, ar_x, ar_y, ar_yerr, ylabel, out_file):
    """
    Plot data points (ar_x, ar_y +- ar_yerr) with their windowed linear
    interpolation (win=3) evaluated on xs, then save to out_file, show, close.
    """
    # Each fit is evaluated once; it returns both the value and its error.
    fits = [window_linear_interp(x, ar_x, ar_y, ar_yerr=ar_yerr, win=3) for x in xs]
    plt.errorbar(ar_x, ar_y, yerr=ar_yerr, fmt='o', label='temperature')
    plt.errorbar(xs, [val for val, _ in fits], yerr=[err for _, err in fits],
                 fmt='-', label='temperature')
    plt.title("interpolation check")
    plt.xlabel("doping")
    plt.ylabel(ylabel)
    plt.savefig(out_file, transparent=True)
    plt.show()
    plt.close()


xs = np.linspace(0.0, 0.3, 100)
# Mz vs. doping in the doping data
_plot_interp_check(xs, dopings_d, mzs_d, mzse_d, "m_z", 'doping_mz.pdf')
# C1 vs. doping in the doping data
_plot_interp_check(xs, dopings_d, c1s_d, c1es_d, "C_s(1)", 'doping_c1.pdf')


#-------------------------------------figure plotting-------------------------------------

fig = plt.figure(figsize=[24 * 0.8, 8 * 0.8])
ax_1 = fig.add_axes([0.03, 0.15, 0.3, 0.8])
ax_2 = fig.add_axes([0.36, 0.15, 0.3, 0.8])
ax_3 = fig.add_axes([0.69, 0.15, 0.3, 0.8])

# The interpolated string lengths were computed while iterating dopings_AS_d
# (percent), so they are plotted at those dopings, converted to fractions.
doping_x = [float(d) / 100. for d in dopings_AS_d]

# add +1 to string length, because length is given by number of sites!

ax_1.errorbar(dopings_d, np.array(asls_d) + 1, yerr=asles_d, fmt='o', label='doping data')
ax_1.errorbar(doping_x, np.array(str_len_mz_interp) + 1, yerr=str_len_mz_interp_err, fmt='go',
              label='effective T via Mz + doping holes + interpol')
ax_1.errorbar(doping_x, np.array(str_len_c1_interp) + 1, yerr=str_len_c1_interp_err, fmt='ro',
              label='effective T via C1 + doping holes + interpol')
ax_1.errorbar(doping_x, np.array(asls_AS_d) + 1, yerr=asles_AS_d)

ax_1.set_xlabel("doping")
ax_1.set_ylabel("average string length")

ax_2.errorbar(mzs_d, np.array(asls_d) + 1, yerr=None, fmt='o', label='doping data')
ax_2.errorbar(crit_mz, np.array(str_len_mz_interp) + 1, yerr=None, fmt='go',
              label='effective T via Mz + doping holes + interpol')
ax_2.set_xlabel("staggered magnetization")

ax_3.errorbar(c1s_d, np.array(asls_d) + 1, yerr=None, fmt='o', label='doping data')
ax_3.errorbar(crit_c1, np.array(str_len_c1_interp) + 1, yerr=None, fmt='ro',
              label='effective T via C1 + doping holes + interpol')
ax_3.set_xlabel("NN spincorr C1")
ax_3.set_ylabel("average string length")

for ax in (ax_1, ax_2, ax_3):
    ax.legend(loc=1)


#-------------------------------------output-------------------------------------


data_dir = "V:\\Paper Data\\AFM_StringTheory\\FCS\\"

# 'x' holds the dopings (fractions) at which str_len_*_interp were computed,
# i.e. those iterated in the interpolation loop (dopings_AS_d, in percent).
dic = {'x': [float(d) / 100. for d in dopings_AS_d],
       'y': np.array(str_len_mz_interp),
       'yerr': str_len_mz_interp_err,
       'x_theo': np.array(dopings_AS_d) / 100.,
       'y_theo': np.array(asls_AS_d),
       'yerr_theo': asles_AS_d,
       'x_exp': dopings_d,
       'y_exp': asls_d,
       'yerr_exp': asles_d,
       'y2_exp': str_len_c1_interp,
       'y2err_exp': str_len_c1_interp_err,
       'doping': dopings_d,
       'dopingerr': dopingerrs_d,
       }

file_n = 'pattern_group2.pkl'
out_path = data_dir + file_n

# report the file actually written (previously the literal 'pattern')
print("Writing doping data to file:  %s" % out_path)
with open(out_path, 'wb') as f:
    pickle.dump(dic, f)

fig.savefig('pattern.pdf', transparent=True)
plt.show()
